@@ -3,7 +3,14 @@
 from django.db import models
 from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
 
-from kitsune.llm.questions.prompt import spam_parser, spam_prompt, topic_parser, topic_prompt
+from kitsune.llm.questions.prompt import (
+    product_parser,
+    product_prompt,
+    spam_parser,
+    spam_prompt,
+    topic_parser,
+    topic_prompt,
+)
 from kitsune.llm.utils import get_llm
 from kitsune.products.utils import get_taxonomy
 
@@ -39,32 +46,59 @@ def classify_question(question: "Question") -> dict[str, Any]:
     }
 
     spam_detection_chain = spam_prompt | llm | spam_parser
+    product_classification_chain = product_prompt | llm | product_parser
     topic_classification_chain = topic_prompt | llm | topic_parser
 
+    def handle_spam(payload: dict[str, Any], spam_result: dict[str, Any]) -> dict[str, Any]:
+        """Handle spam classification with potential product reclassification."""
+        confidence = spam_result.get("confidence", 0)
+        match confidence:
+            case _ if confidence >= HIGH_CONFIDENCE_THRESHOLD:
+                action = ModerationAction.SPAM
+            case _ if confidence > LOW_CONFIDENCE_THRESHOLD:
+                action = ModerationAction.FLAG_REVIEW
+            case _:
+                action = ModerationAction.NOT_SPAM
+
+        if action != ModerationAction.SPAM:
+            return {"action": action, "product_result": {}}
+
+        product_result = product_classification_chain.invoke(payload)
+        new_product = product_result.get("product")
+
+        if new_product and new_product != payload["product"]:
+            payload["product"] = new_product
+            payload["topics"] = get_taxonomy(
+                new_product, include_metadata=["description", "examples"], output_format="JSON"
+            )
+            topic_result = topic_classification_chain.invoke(payload)
+            return {
+                "action": ModerationAction.NOT_SPAM,
+                "product_result": product_result,
+                "topic_result": topic_result,
+            }
+        else:
+            return {
+                "action": ModerationAction.SPAM,
+                "product_result": product_result,
+            }
+
     def decision_lambda(payload: dict[str, Any]) -> dict[str, Any]:
         spam_result: dict[str, Any] = payload["spam_result"]
-        confidence: int = spam_result.get("confidence", 0)
         is_spam: bool = spam_result.get("is_spam", False)
-        result = {
-            "action": ModerationAction.NOT_SPAM,
+
+        base_result = {
             "spam_result": spam_result,
+            "product_result": {},
             "topic_result": {},
         }
 
         if is_spam:
-            match confidence:
-                case _ if confidence >= HIGH_CONFIDENCE_THRESHOLD:
-                    result["action"] = ModerationAction.SPAM
-                case _ if (
-                    confidence > LOW_CONFIDENCE_THRESHOLD
-                    and confidence < HIGH_CONFIDENCE_THRESHOLD
-                ):
-                    result["action"] = ModerationAction.FLAG_REVIEW
-
-        if result["action"] == ModerationAction.NOT_SPAM:
-            result["topic_result"] = topic_classification_chain.invoke(payload)
-
-        return result
+            spam_handling = handle_spam(payload, spam_result)
+            return {**base_result, **spam_handling}
+
+        topic_result = topic_classification_chain.invoke(payload)
+        return {**base_result, "topic_result": topic_result}
 
     pipeline = RunnablePassthrough.assign(spam_result=spam_detection_chain) | RunnableLambda(
         decision_lambda
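Taken together, the hunks move the spam decision into handle_spam: a confidence at or above HIGH_CONFIDENCE_THRESHOLD yields a spam verdict (subject to the product check), a confidence strictly above LOW_CONFIDENCE_THRESHOLD flags the question for human review, and anything else is treated as not spam. The standalone sketch below mirrors only that thresholding; the threshold values and the ModerationAction members are illustrative assumptions, since their definitions sit outside this diff.

from enum import Enum
from typing import Any

# Assumed, illustrative values; the real constants are defined elsewhere in the module.
HIGH_CONFIDENCE_THRESHOLD = 75
LOW_CONFIDENCE_THRESHOLD = 60

class ModerationAction(Enum):
    SPAM = "spam"
    FLAG_REVIEW = "flag_review"
    NOT_SPAM = "not_spam"

def decide(spam_result: dict[str, Any]) -> ModerationAction:
    # Same ordering as the match statement in handle_spam: the spam
    # threshold wins first, then the review band, then the default.
    confidence = spam_result.get("confidence", 0)
    if confidence >= HIGH_CONFIDENCE_THRESHOLD:
        return ModerationAction.SPAM
    if confidence > LOW_CONFIDENCE_THRESHOLD:
        return ModerationAction.FLAG_REVIEW
    return ModerationAction.NOT_SPAM

assert decide({"confidence": 90}) is ModerationAction.SPAM
assert decide({"confidence": 70}) is ModerationAction.FLAG_REVIEW
assert decide({"confidence": 60}) is ModerationAction.NOT_SPAM  # review band is strictly above the low threshold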
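The other behavioral change is the rescue path: even a high-confidence spam verdict is overturned when the product classifier proposes a product different from the one the question was filed under, in which case the question is re-homed and re-topiced instead of being marked as spam. Below is a minimal simulation of that branch; the StubChain class and the sample products are hypothetical stand-ins for the real prompt | llm | parser runnables.

from typing import Any

class StubChain:
    """Hypothetical stand-in for a prompt | llm | parser runnable."""

    def __init__(self, result: dict[str, Any]) -> None:
        self.result = result

    def invoke(self, payload: dict[str, Any]) -> dict[str, Any]:
        return self.result

product_classification_chain = StubChain({"product": "Thunderbird", "confidence": 88})
topic_classification_chain = StubChain({"topic": "Installation and updates"})

payload: dict[str, Any] = {"product": "Firefox", "topics": "..."}

# A high-confidence spam verdict has already been reached; check whether the
# product classifier disagrees with the product the question was filed under.
product_result = product_classification_chain.invoke(payload)
new_product = product_result.get("product")

if new_product and new_product != payload["product"]:
    payload["product"] = new_product                           # re-home the question
    topic_result = topic_classification_chain.invoke(payload)  # and re-topic it
    print("NOT_SPAM", product_result, topic_result)
else:
    print("SPAM", product_result)                              # classifier agrees: keep the verdict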